{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ogushXV4ZGMi"
      },
      "source": [
        "# Fine-Tuning an OpenSource Model using QLoRA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "chQmy4_HXhgr"
      },
      "outputs": [],
      "source": [
        "# Install fine-tuning dependencies (Colab). NOTE(review): consider pinning\n",
        "# versions (e.g. peft==x.y.z) so the notebook stays reproducible.\n",
        "%pip install -qU peft trl bitsandbytes datasets wandb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHZJoUeQZJNo"
      },
      "outputs": [],
      "source": [
        "# All imports in one place. TrainingArguments was removed: it was never\n",
        "# used anywhere in this notebook (SFTConfig carries the training args).\n",
        "import torch\n",
        "from transformers import (\n",
        "    AutoModelForCausalLM,\n",
        "    AutoTokenizer,\n",
        "    BitsAndBytesConfig\n",
        ")\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "from trl import SFTConfig, SFTTrainer\n",
        "from datasets import load_dataset"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "MODEL_NAME = \"mistralai/Mistral-7B-v0.1\"\n",
        "\n",
        "\n",
        "# QLoRA quantization: load base weights in 4-bit NF4, with double\n",
        "# quantization for extra memory savings; computation runs in bfloat16.\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16\n",
        ")"
      ],
      "metadata": {
        "id": "NyDHu1vrZ1gO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the base model in 4-bit across available devices.\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    MODEL_NAME,\n",
        "    quantization_config=bnb_config,\n",
        "    device_map=\"auto\",\n",
        "    trust_remote_code=True\n",
        ")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
        "# The base tokenizer has no pad token, so reuse EOS; pad on the right\n",
        "# for supervised fine-tuning.\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "tokenizer.padding_side = \"right\"\n",
        "\n",
        "\n",
        "# Standard PEFT preparation step for training on top of quantized weights.\n",
        "model = prepare_model_for_kbit_training(model)"
      ],
      "metadata": {
        "id": "v2X2xM-dZ7fN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# LoRA adapters on all four attention projections; r=8 / alpha=16 with\n",
        "# light dropout is a common low-memory starting point.\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    lora_alpha=16,\n",
        "    target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],\n",
        "    lora_dropout=0.05,\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\"\n",
        ")\n",
        "\n",
        "model = get_peft_model(model, lora_config)\n",
        "# Sanity check: only the adapter weights should be trainable.\n",
        "model.print_trainable_parameters()"
      ],
      "metadata": {
        "id": "akK6pJnLaeIr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Parquet files staged locally in Colab — presumably the Amazon Reviews\n",
        "# \"All Beauty\" category; confirm provenance with the data-prep notebook.\n",
        "DATA_DIR = '/content/data'\n",
        "\n",
        "data_files = [\n",
        "    f'{DATA_DIR}/all_beauty_train.parquet',\n",
        "]\n",
        "\n",
        "dataset = load_dataset('parquet', data_files=data_files, split='train')\n",
        "\n",
        "# Seeded subsample of 100 rows for a quick training run.\n",
        "# NOTE(review): the 20-row \"test\" half of this split is never used —\n",
        "# evaluation reads the separate test parquet below.\n",
        "train_test = dataset.train_test_split(train_size=100, test_size=20, seed=42)\n",
        "train_dataset = train_test[\"train\"]\n",
        "\n",
        "test_dataset = load_dataset('parquet', data_files=[f'{DATA_DIR}/all_beauty_test.parquet'], split='train')"
      ],
      "metadata": {
        "id": "Nm1O_IjBbNa0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Training configuration. save_total_limit was removed: it is dead config\n",
        "# when save_strategy=\"no\" (no checkpoints are ever written).\n",
        "sft_config = SFTConfig(\n",
        "    output_dir=\"./price-prediction-qlora\",\n",
        "    num_train_epochs=1,\n",
        "    per_device_train_batch_size=4,\n",
        "    gradient_accumulation_steps=2,  # effective batch size of 8\n",
        "    gradient_checkpointing=True,  # trade compute for memory on a T4\n",
        "    optim=\"paged_adamw_8bit\",\n",
        "    learning_rate=2e-4,\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    warmup_steps=50,\n",
        "    logging_steps=10,\n",
        "    save_strategy=\"no\",  # adapters are saved manually after training\n",
        "    fp16=False,\n",
        "    bf16=True,\n",
        "    max_grad_norm=0.3,\n",
        "    group_by_length=True,  # batch similar-length samples to cut padding\n",
        "    report_to=\"none\",\n",
        "    packing=False,\n",
        "    dataset_text_field=\"text\",\n",
        ")\n",
        "\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    args=sft_config,\n",
        "    train_dataset=train_dataset,\n",
        ")\n",
        "\n"
      ],
      "metadata": {
        "id": "a8Nx8GJTb-Wz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Start Training"
      ],
      "metadata": {
        "id": "UmH5E6Xvn8so"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Starting training...\")\n",
        "\n",
        "trainer.train()\n",
        "\n",
        "# Saves only the small LoRA adapter weights, not the full base model.\n",
        "trainer.model.save_pretrained(\"./price-prediction-final\")\n",
        "tokenizer.save_pretrained(\"./price-prediction-final\")\n",
        "\n",
        "print(\"Training complete! LoRA adapters saved to ./price-prediction-final\")"
      ],
      "metadata": {
        "id": "-nXZ5O_ifFVh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "\n",
        "\n",
        "def predict_price_inmemory(prompt, model, tokenizer):\n",
        "    \"\"\"Generate a completion for `prompt` and extract the predicted price.\n",
        "\n",
        "    Returns the numeric string after the \"Price is $\" marker when present;\n",
        "    otherwise returns the full decoded text so the caller can inspect it.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "\n",
        "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):\n",
        "          outputs = model.generate(\n",
        "              **inputs,\n",
        "              max_new_tokens=10,\n",
        "              do_sample=False,  # greedy decoding; `temperature` removed — it is ignored (with a warning) when do_sample=False\n",
        "              pad_token_id=tokenizer.eos_token_id\n",
        "          )\n",
        "\n",
        "    result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "\n",
        "    if \"Price is $\" in result:\n",
        "        predicted = result.split(\"Price is $\")[-1].strip()\n",
        "        match = re.search(r'(\\d+\\.?\\d*)', predicted)\n",
        "        if match:\n",
        "            return match.group(1)\n",
        "        return predicted  # marker found but no number in the tail\n",
        "    # Bug fix: the original returned `predicted` here, which is undefined\n",
        "    # (NameError) whenever the marker is absent from the model output.\n",
        "    return result"
      ],
      "metadata": {
        "id": "LSvPhf-3fYaZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Eye-test Validation\n",
        "\n",
        "Not the best validation approach, I know — but I wanted to go through the whole process myself, and I didn't have enough time for a more rigorous evaluation."
      ],
      "metadata": {
        "id": "m6L5CET_sXmx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "for item in test_dataset.take(5):\n",
        "  prompt = item[\"text\"]\n",
        "  actual_price = item[\"price\"]\n",
        "\n",
        "  raw_prediction = predict_price_inmemory(prompt, model, tokenizer)\n",
        "  print(\"\\n\" + \"*\" * 80)\n",
        "  print(prompt)\n",
        "\n",
        "  # Bug fix: float() raised an unhandled ValueError whenever the model\n",
        "  # produced non-numeric text, killing the whole eval loop.\n",
        "  try:\n",
        "    predicted_price = float(raw_prediction)\n",
        "  except (TypeError, ValueError):\n",
        "    print(f\"Could not parse prediction {raw_prediction!r}. Actual: ${actual_price}\")\n",
        "    continue\n",
        "\n",
        "  print(f\"Prediction: ${predicted_price}.  Actual: ${actual_price}. Diff {abs(predicted_price - actual_price):,.2f}\")\n"
      ],
      "metadata": {
        "id": "6pSYhLn_kROQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Loading the Adapter Later\n",
        "\n",
        "The saved LoRA adapter can even be loaded in a different notebook."
      ],
      "metadata": {
        "id": "OnJAD7YihyAD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
        "from peft import PeftModel\n",
        "import torch\n",
        "\n",
        "# Redefined here so this cell is genuinely self-contained — the original\n",
        "# depended on MODEL_NAME from a cell far above, contradicting the claim\n",
        "# that it can run in a different notebook.\n",
        "MODEL_NAME = \"mistralai/Mistral-7B-v0.1\"\n",
        "ADAPTER_DIR = \"./price-prediction-final\"\n",
        "\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16\n",
        ")\n",
        "\n",
        "base_model = AutoModelForCausalLM.from_pretrained(\n",
        "    MODEL_NAME,\n",
        "    quantization_config=bnb_config,\n",
        "    device_map=\"auto\",\n",
        "    trust_remote_code=True\n",
        ")\n",
        "\n",
        "\n",
        "# Attach the trained LoRA adapters on top of the 4-bit base model.\n",
        "model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)\n",
        "tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)"
      ],
      "metadata": {
        "id": "5RCClHQHijes"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}